{
 "cells": [
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### CRF"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "基于 sklearn_crfsuite 搭建命名实体识别（NER）系统，本例来自 sklearn_crfsuite 官方教程（tutorial）。"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 3,
   "metadata": {},
   "outputs": [],
   "source": [
    "# 导入相关库\n",
    "import nltk\n",
    "import sklearn\n",
    "import scipy.stats\n",
    "from sklearn.metrics import make_scorer\n",
    "from sklearn.model_selection import cross_val_score\n",
    "from sklearn.model_selection import RandomizedSearchCV\n",
    "\n",
    "import sklearn_crfsuite\n",
    "from sklearn_crfsuite import scorers\n",
    "from sklearn_crfsuite import metrics"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 5,
   "metadata": {},
   "outputs": [
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "[nltk_data] Downloading package conll2002 to\n",
      "[nltk_data]     C:\\Users\\92070\\AppData\\Roaming\\nltk_data...\n",
      "[nltk_data]   Unzipping corpora\\conll2002.zip.\n"
     ]
    },
    {
     "data": {
      "text/plain": [
       "True"
      ]
     },
     "execution_count": 5,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "# Download the conll2002 sample dataset through NLTK\n",
    "nltk.download('conll2002')"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 6,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Load the training and test sentences from the conll2002 corpus files\n",
    "train_sents, test_sents = (\n",
    "    list(nltk.corpus.conll2002.iob_sents(fileid))\n",
    "    for fileid in ('esp.train', 'esp.testb')\n",
    ")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 7,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "[('Melbourne', 'NP', 'B-LOC'),\n",
       " ('(', 'Fpa', 'O'),\n",
       " ('Australia', 'NP', 'B-LOC'),\n",
       " (')', 'Fpt', 'O'),\n",
       " (',', 'Fc', 'O'),\n",
       " ('25', 'Z', 'O'),\n",
       " ('may', 'NC', 'O'),\n",
       " ('(', 'Fpa', 'O'),\n",
       " ('EFE', 'NC', 'B-ORG'),\n",
       " (')', 'Fpt', 'O'),\n",
       " ('.', 'Fp', 'O')]"
      ]
     },
     "execution_count": 7,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "# Show the first training sentence\n",
    "train_sents[0]"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 8,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Convert a word into a dict of features\n",
    "def word2features(sent, i):\n",
    "    \"\"\"Build the feature dict for the token at position ``i`` of ``sent``.\n",
    "\n",
    "    Args:\n",
    "        sent: Sequence of token records; field 0 of a record is the word\n",
    "            and field 1 is its POS tag. Any further fields are not read.\n",
    "        i: Index of the token to describe.\n",
    "\n",
    "    Returns:\n",
    "        dict: features of the token itself, followed by features of the\n",
    "        previous token (or ``'BOS': True`` when ``i > 0`` is false) and of\n",
    "        the next token (or ``'EOS': True`` when ``i < len(sent) - 1`` is\n",
    "        false).\n",
    "    \"\"\"\n",
    "    token, tag = sent[i][0], sent[i][1]\n",
    "\n",
    "    feats = {\n",
    "        'bias': 1.0,\n",
    "        'word.lower()': token.lower(),\n",
    "        'word[-3:]': token[-3:],\n",
    "        'word[-2:]': token[-2:],\n",
    "        'word.isupper()': token.isupper(),\n",
    "        'word.istitle()': token.istitle(),\n",
    "        'word.isdigit()': token.isdigit(),\n",
    "        'postag': tag,\n",
    "        'postag[:2]': tag[:2],\n",
    "    }\n",
    "\n",
    "    # One row per side: (key prefix, neighbour exists?, neighbour index, boundary flag)\n",
    "    neighbours = (\n",
    "        ('-1:', i > 0, i - 1, 'BOS'),\n",
    "        ('+1:', i < len(sent) - 1, i + 1, 'EOS'),\n",
    "    )\n",
    "    for prefix, exists, j, boundary in neighbours:\n",
    "        if not exists:\n",
    "            # No neighbour on this side: set the boundary flag instead.\n",
    "            feats[boundary] = True\n",
    "            continue\n",
    "        nb_token, nb_tag = sent[j][0], sent[j][1]\n",
    "        feats[prefix + 'word.lower()'] = nb_token.lower()\n",
    "        feats[prefix + 'word.istitle()'] = nb_token.istitle()\n",
    "        feats[prefix + 'word.isupper()'] = nb_token.isupper()\n",
    "        feats[prefix + 'postag'] = nb_tag\n",
    "        feats[prefix + 'postag[:2]'] = nb_tag[:2]\n",
    "\n",
    "    return feats\n",
    "\n",
    "\n",
    "def sent2features(sent):\n",
    "    \"\"\"Return one word2features() dict per position in ``sent``, in order.\"\"\"\n",
    "    return [word2features(sent, pos) for pos in range(len(sent))]\n",
    "\n",
    "def sent2labels(sent):\n",
    "    \"\"\"Return the label of every (token, postag, label) triple in ``sent``.\"\"\"\n",
    "    return [tag for _token, _postag, tag in sent]\n",
    "\n",
    "def sent2tokens(sent):\n",
    "    \"\"\"Return the token of every (token, postag, label) triple in ``sent``.\"\"\"\n",
    "    return [word for word, _postag, _label in sent]"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 9,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "{'bias': 1.0,\n",
       " 'word.lower()': 'melbourne',\n",
       " 'word[-3:]': 'rne',\n",
       " 'word[-2:]': 'ne',\n",
       " 'word.isupper()': False,\n",
       " 'word.istitle()': True,\n",
       " 'word.isdigit()': False,\n",
       " 'postag': 'NP',\n",
       " 'postag[:2]': 'NP',\n",
       " 'BOS': True,\n",
       " '+1:word.lower()': '(',\n",
       " '+1:word.istitle()': False,\n",
       " '+1:word.isupper()': False,\n",
       " '+1:postag': 'Fpa',\n",
       " '+1:postag[:2]': 'Fp'}"
      ]
     },
     "execution_count": 9,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "# Feature dict of the first token of the first training sentence\n",
    "sent2features(train_sents[0])[0]"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 10,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Build the training and test sets: per-sentence feature lists (X) and label lists (y)\n",
    "X_train = list(map(sent2features, train_sents))\n",
    "y_train = list(map(sent2labels, train_sents))\n",
    "\n",
    "X_test = list(map(sent2features, test_sents))\n",
    "y_test = list(map(sent2labels, test_sents))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 11,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "8323 1517\n"
     ]
    }
   ],
   "source": [
    "# Number of sentences in the training set and in the test set\n",
    "print(len(X_train), len(X_test))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 18,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "0.7964686316443963"
      ]
     },
     "execution_count": 18,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "# Create the CRF model instance\n",
    "# NOTE(review): c1 / c2 are presumably the L1 / L2 regularization coefficients;\n",
    "# confirm against the sklearn_crfsuite.CRF documentation.\n",
    "crf = sklearn_crfsuite.CRF(\n",
    "    algorithm='lbfgs',\n",
    "    c1=0.1,\n",
    "    c2=0.1,\n",
    "    max_iterations=100,\n",
    "    all_possible_transitions=True\n",
    ")\n",
    "# Train the model\n",
    "crf.fit(X_train, y_train)\n",
    "# Class labels; 'O' is removed so that it is not in the `labels` list\n",
    "# passed to the metric below\n",
    "labels = list(crf.classes_)\n",
    "labels.remove('O')\n",
    "# Predict on the test set\n",
    "y_pred = crf.predict(X_test)\n",
    "# Compute the F1 score (weighted average, restricted to `labels`);\n",
    "# as the last expression, its value is the cell output\n",
    "metrics.flat_f1_score(y_test, y_pred,\n",
    "                      average='weighted', labels=labels)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 19,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "              precision    recall  f1-score   support\n",
      "\n",
      "       B-LOC      0.810     0.784     0.797      1084\n",
      "       I-LOC      0.690     0.637     0.662       325\n",
      "      B-MISC      0.731     0.569     0.640       339\n",
      "      I-MISC      0.699     0.589     0.639       557\n",
      "       B-ORG      0.807     0.832     0.820      1400\n",
      "       I-ORG      0.852     0.786     0.818      1104\n",
      "       B-PER      0.850     0.884     0.867       735\n",
      "       I-PER      0.893     0.943     0.917       634\n",
      "\n",
      "   micro avg      0.813     0.787     0.799      6178\n",
      "   macro avg      0.791     0.753     0.770      6178\n",
      "weighted avg      0.809     0.787     0.796      6178\n",
      "\n"
     ]
    }
   ],
   "source": [
    "# Print the per-label results for the B- and I- labels.\n",
    "# Sort key: (label without its first character, first character),\n",
    "# e.g. 'B-LOC' -> ('-LOC', 'B'), so B-LOC is immediately followed by I-LOC.\n",
    "sorted_labels = sorted(labels, key=lambda tag: (tag[1:], tag[0]))\n",
    "print(metrics.flat_classification_report(\n",
    "    y_test,\n",
    "    y_pred,\n",
    "    labels=sorted_labels,\n",
    "    digits=3,\n",
    "))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": []
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.7.3"
  },
  "toc": {
   "base_numbering": 1,
   "nav_menu": {},
   "number_sections": true,
   "sideBar": true,
   "skip_h1_title": false,
   "title_cell": "Table of Contents",
   "title_sidebar": "Contents",
   "toc_cell": false,
   "toc_position": {},
   "toc_section_display": true,
   "toc_window_display": false
  }
 },
 "nbformat": 4,
 "nbformat_minor": 2
}
